#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import pandas as pd
from jm.library.data_helper import filter_data, date_to_str, is_weakly_less_than
from jm.library.likelihood import Likelihood
import matplotlib.ticker as mtick
from jm.library.additional_helpers import get_cols_over_time, plot_over_time, payment_cols, priority_cols, action_cols, \
    relative_payment_cols, fwdiff_payment_cols, relative_fwdiff_payment_cols, action_nomedida_cols

# Matplotlib defaults shared by the figures produced below.
rcParams.update({
    'figure.figsize': (13.0, 7.0),
    'lines.linewidth': 3,
    'font.size': 18,
})

np.random.seed(1234)

def to_percent(y_value, _):
    """Format a fraction as a percentage string with one decimal place.

    Intended for ``matplotlib.ticker.FuncFormatter``; the second argument
    (the tick position) is ignored.
    """
    return "{:.1f}%".format(y_value * 100)


# Figure defaults (figsize, line width, font size) are set once, right after
# the imports; repeating the same rcParams assignments here had no effect.


# In[4]:

df_status = pd.read_csv('data/jesus_maria_status_voluntary_deidentified.csv')

# Express each payment column group relative to the row's total_due
# (the "fwdiff" group is presumably forward-differenced payments, per its name).
status_total_due = df_status['total_due']
df_status[relative_payment_cols] = df_status[payment_cols].divide(status_total_due, axis=0)
df_status[relative_fwdiff_payment_cols] = df_status[fwdiff_payment_cols].divide(
    status_total_due, axis=0)



# Figure OA3: cumulative collections (in millions of soles), treatment vs. control.
plt.clf()
df_treated_all = filter_data(df_status, assignment_to_treatment=1)
ax = plot_over_time(df_treated_all, normalize=1e6)
df_control_all = filter_data(df_status, assignment_to_treatment=0)
plot_over_time(df_control_all, normalize=1e6, ax=ax, linestyle='dashed')

plt.legend(['treatment', 'control'])
plt.ylabel('cumulative taxes collected (Millions of Soles)')
plt.tight_layout()
plt.savefig('figs/figOA3.pdf')

# In[14]:


# Total payments by 2021-09-06, per treatment-assignment arm.
p = (
    df_status
    .groupby('assignment_to_treatment')['payments_by_2021-09-06']
    .sum()
)
# Relative gain of arm 1 (treatment) over arm 0 (control), written for Section OA3.
with open("numbers/secOA3.txt", "w") as text_file:
    relative_gain = p[1] / p[0] - 1
    print(relative_gain, file=text_file)


def is_payment_above_thresh_data(df, thresh=.5, ax=None):
    """Flag, per row and event date, whether the relative payment exceeds ``thresh``.

    Compares ``df[relative_payment_cols]`` element-wise against ``thresh``
    and relabels the resulting boolean columns with ``Likelihood.event_dates``.
    ``ax`` is accepted for call-site symmetry but is not used.
    """
    exceeds = df[relative_payment_cols].gt(thresh)
    exceeds.columns = Likelihood.event_dates
    return exceeds


def is_payment_above_thresh(df, thresh=.5, ax=None, linestyle='solid'):
    """Plot, per event date, the share of rows whose relative payment exceeds ``thresh``.

    Parameters
    ----------
    df : DataFrame
        Rows to summarise; must contain ``relative_payment_cols``.
    thresh : float
        Cut-off applied to payment / tax-due.
    ax : matplotlib Axes, optional
        Axes to draw on. When None, pandas chooses the axes.
    linestyle : str
        Matplotlib line style of the plotted curve.

    Returns
    -------
    The Axes the curve was drawn on.
    """
    above = is_payment_above_thresh_data(df, thresh, ax)
    # Draw on the caller's axes rather than ignoring the ``ax`` argument.
    ax = above.mean(axis=0).plot(ax=ax, linestyle=linestyle)
    # Label with the actual threshold (e.g. 0.5 -> "50%") instead of a fixed value.
    ax.set_ylabel(f'share paid/due > {100 * thresh:g}%')
    return ax


def generate_sample(df, covariate, sample_size, num_quantiles, replace=False):
    """Return row labels for batch-1 units plus a covariate-stratified control draw.

    Batch-1 rows are those with ``batch1_top + batch1_bottom > 0``. The batch-1
    distribution of ``covariate`` is cut into ``num_quantiles`` equal-probability
    bins; from each bin ``sample_size`` control rows
    (``assignment_to_treatment == 0``) are drawn with ``np.random.choice``.

    Parameters
    ----------
    df : DataFrame
        Must contain ``batch1_top``, ``batch1_bottom``,
        ``assignment_to_treatment`` and ``covariate``.
    covariate : str
        Column used to stratify the control draw.
    sample_size : int
        Number of controls drawn per bin.
    num_quantiles : int
        Number of bins.
    replace : bool
        Passed to ``np.random.choice``.

    Returns
    -------
    list
        Unique index labels: batch-1 rows first, then sampled controls in draw
        order. A label drawn more than once (``replace=True``) or a control that
        is also in batch 1 appears only once.

    Raises
    ------
    ValueError
        From ``np.random.choice`` when a bin holds no control rows, or fewer
        than ``sample_size`` of them with ``replace=False``.

    Note: this reseeds the global NumPy RNG with 0 on every call.
    """
    np.random.seed(0)
    batch1 = df['batch1_top'] + df['batch1_bottom'] > 0

    probs = [i / num_quantiles for i in range(num_quantiles + 1)]
    edges = df.loc[batch1, covariate].quantile(probs)
    # Bins are left-closed, so push the top edge strictly above the batch-1
    # maximum. A fixed +1e-6 vanishes for large-magnitude values (e.g. 1e12),
    # so fall back to the next representable float in that case.
    top = edges.iloc[-1]
    edges.iloc[-1] = max(top + 1e-6, np.nextafter(top, np.inf))

    control = df.loc[df.assignment_to_treatment == 0, :]

    drawn = []
    for i in range(num_quantiles):
        # Positional access avoids relying on exact float equality of labels.
        lower, upper = edges.iloc[i], edges.iloc[i + 1]
        in_bin = control[covariate].between(lower, upper, inclusive='left')
        candidates = list(control.loc[in_bin].index)
        drawn += list(np.random.choice(candidates, sample_size, replace=replace))

    # Order-preserving de-duplication keeps the row order reproducible.
    return list(dict.fromkeys(list(df.loc[batch1].index) + drawn))


# Batch-1 taxpayers plus up to 60 controls drawn from each of 5 quantile bins
# of 'score_endo_covariates' (bins taken from the batch-1 distribution).
selected_sample = generate_sample(df_status, 'score_endo_covariates', 60, 5)
df_status_b1 = df_status.loc[selected_sample, :]

# Figure OA5: priority vs. no priority, batch 1.
plt.clf()

b1_treated = filter_data(df_status_b1, assignment_to_treatment=1)
ax = is_payment_above_thresh(b1_treated, thresh=.5)
b1_control = filter_data(df_status_b1, assignment_to_treatment=0)
is_payment_above_thresh(b1_control, thresh=.5, ax=ax, linestyle='dashed')

ax.set_ylim([-0.025, 0.8])
plt.legend(['priority G1', 'comparable control'], loc='upper left')
ax.yaxis.set_major_formatter(mtick.FuncFormatter(to_percent))
plt.ylabel('Share of Tax-payers with \n Payments / Tax-Due > .5', fontsize=16)

plt.tight_layout()
plt.savefig('figs/figOA5.pdf')

# In[47]:


# For each percentile q of total_due: the share of each arm's total collections
# coming from taxpayers whose total_due is at or below that percentile value.
quantiles = [i/100. for i in range(101)]
quantiles_due = list(df_status['total_due'].quantile(quantiles))

# Per-arm totals used as the denominator of every share.
total_collected = df_status.groupby('assignment_to_treatment')['payments_by_2021-09-06'].sum()

# Explicit float dtype: an empty DataFrame(index=..., columns=...) defaults to
# object dtype, which would leave the result table non-numeric.
share_collected_by_quantile = pd.DataFrame(index=quantiles, columns=[0, 1], dtype=float)
for q, d in zip(quantiles, quantiles_due):
    this_df = filter_data(df_status, total_due=is_weakly_less_than(d))
    # The row assignment aligns the per-arm Series (index 0/1) onto columns 0/1.
    share_collected_by_quantile.loc[q] = this_df.groupby(
        'assignment_to_treatment')['payments_by_2021-09-06'].sum()/total_collected

# In[50]:

# Figure OA4: cumulative share of collections by quantile of tax due.
plt.clf()
ax = share_collected_by_quantile[[1]].plot(figsize=(13, 7))
share_collected_by_quantile[[0]].plot(ax=ax, figsize=(13, 7), linestyle="dashed")

ax.set_xlabel('quantile of tax-due')
ax.set_ylabel('share of taxes collected')
ax.legend(['treatment', 'control'])

plt.tight_layout()
plt.savefig('figs/figOA4.pdf')
